from __future__ import absolute_import
from __future__ import print_function
import os
import sys
import numpy as np

sys.path.append('../tool')
import toolkits
import src.utils as ut
import datetime
import pandas as pd

# ===========================================================
#   Test each speaker using the per-speaker scores produced by
#   the model trained in train_predict_score.py
# ===========================================================
import argparse

# Command-line configuration. `args` is module-level and read directly by main().
parser = argparse.ArgumentParser()
# set up training configuration.
parser.add_argument('--gpu', default='', type=str)
parser.add_argument('--resume', default='../model/gvlad_softmax/resnet34_vlad8_ghost2_bdim512_deploy/weights.h5',
                    type=str)
parser.add_argument('--batch_size', default=16, type=int)
# enrollment wavs: one sub-directory per speaker under this path
parser.add_argument('--data_path', default='../audio_only9', type=str)
# probe wavs to be scored against every enrolled speaker
parser.add_argument('--test_path', default='../test_nosur', type=str)
# set up network configuration.
parser.add_argument('--net', default='resnet34s', choices=['resnet34s', 'resnet34l'], type=str)
parser.add_argument('--ghost_cluster', default=2, type=int)
parser.add_argument('--vlad_cluster', default=8, type=int)
parser.add_argument('--bottleneck_dim', default=512, type=int)
parser.add_argument('--aggregation_mode', default='gvlad', choices=['avg', 'vlad', 'gvlad'], type=str)
# set up learning rate, training loss and optimizer.
parser.add_argument('--loss', default='softmax', choices=['softmax', 'amsoftmax'], type=str)
parser.add_argument('--test_type', default='normal', choices=['normal', 'hard', 'extend'], type=str)

# NOTE: the previous `global args` was a no-op at module scope and has been removed;
# `args` is already a module-level name.
args = parser.parse_args()


def main():
    """Score every probe wav against every enrolled speaker and save the results.

    Workflow:
      1. Enumerate all probe wavs under ``args.test_path`` and all enrollment
         wavs under ``args.data_path`` (one sub-directory per speaker) and
         build every (probe, enrollment) trial pair.
      2. Embed each unique wav exactly once with the speaker-verification
         network loaded from ``args.resume``.
      3. Score every pair by the inner product of the two embeddings.
      4. For each (probe wav, enrolled speaker) keep the maximum score over
         that speaker's enrollment utterances and write the table to
         ``../result/no_sur/predict_score_list.txt``.

    Raises:
        IOError: if ``args.resume`` is empty or not an existing file.
    """
    # gpu configuration
    toolkits.initialize_GPU(args)

    # Deferred import: the model module builds the network graph, so the GPU
    # must be configured first.
    import model

    # ==================================
    #       Build trial pairs: every probe wav vs. every enrollment wav.
    # ==================================
    test_list = []
    for test_wav in os.listdir(args.test_path):
        test_file = os.path.join(args.test_path, test_wav)
        for category in os.listdir(args.data_path):
            speaker_dir = os.path.join(args.data_path, category)
            for category2 in os.listdir(speaker_dir):
                # Enrollment path is kept relative to data_path and re-joined below.
                test_list.append([1, test_file, os.path.join(category, category2)])

    list1 = np.array([i[1] for i in test_list])                                # probe wav paths
    list2 = np.array([os.path.join(args.data_path, i[2]) for i in test_list])  # enrollment wav paths
    # Each wav is embedded only once even though it appears in many pairs.
    unique_list = np.unique(np.concatenate((list1, list2)))
    # O(1) path -> embedding-row lookup; replaces an O(n) np.where() per pair.
    index_of = {path: idx for idx, path in enumerate(unique_list)}

    # ==================================
    #       Get Model
    # ==================================
    # construct the data generator.
    params = {'dim': (257, None, 1),
              'nfft': 512,
              'spec_len': 250,
              'win_length': 400,
              'hop_length': 160,
              'n_classes': 5994,
              'sampling_rate': 16000,
              'normalize': True,
              }

    network_eval = model.vggvox_resnet2d_icassp(input_dim=params['dim'],
                                                num_class=params['n_classes'],
                                                mode='eval', args=args)

    # ==> load pre-trained model weights.
    if args.resume:
        if os.path.isfile(args.resume):
            network_eval.load_weights(os.path.join(args.resume), by_name=True)
            print('==> successfully loading model {}.'.format(args.resume))
        else:
            raise IOError("==> no checkpoint found at '{}'".format(args.resume))
    else:
        raise IOError('==> please type in the model to load')

    print('==> start testing.')

    # The feature extraction process has to be done sample-by-sample,
    # because each sample is of different lengths.
    total_length = len(unique_list)
    feats, scores = [], []
    for c, ID in enumerate(unique_list):
        if c % 50 == 0: print('Finish extracting features for {}/{}th wav.'.format(c, total_length))
        specs = ut.load_data(ID, win_length=params['win_length'], sr=params['sampling_rate'],
                             hop_length=params['hop_length'], n_fft=params['nfft'],
                             spec_len=params['spec_len'], mode='eval')
        specs = np.expand_dims(np.expand_dims(specs, 0), -1)
        feats.append(network_eval.predict(specs))

    feats = np.array(feats)

    # ==> compute the pair-wise similarity (inner product of the embeddings).
    re_list = []
    result_list = []
    for p1, p2 in zip(list1, list2):
        ind1 = index_of[p1]
        ind2 = index_of[p2]
        v1 = feats[ind1, 0]
        v2 = feats[ind2, 0]

        scores += [np.sum(v1 * v2)]
        print(unique_list[ind1], unique_list[ind2])
        print('scores : {}'.format(scores[-1]))
        re_list.append([ind1, ind2, scores[-1]])
        # Probe speaker label: filename up to the first '1'
        # (e.g. 'alice1.wav' -> 'alice'). basename() is robust to path depth
        # and OS separators, unlike the previous hard-coded split('/')[2].
        test_name = os.path.basename(str(p1)).split('1')[0]
        # Enrolled speaker = first directory component under data_path.
        train_name = os.path.relpath(str(p2), args.data_path).split(os.sep)[0]
        result_list.append([unique_list[ind1], test_name, train_name, scores[-1]])

    result_frame = pd.DataFrame(result_list, columns=['test_wav', 'test_speaker', 'train_speaker',
                                                      'predict_score'])
    unique_wav_list = np.unique(result_frame['test_wav'])
    unique_train_speaker = np.unique(result_frame['train_speaker'])

    # For every (probe wav, enrolled speaker) keep the best score over that
    # speaker's enrollment utterances.
    score_list = []
    for wav in unique_wav_list:
        test_wav_frame = result_frame[result_frame['test_wav'] == wav]
        test_wav_frame = test_wav_frame.reset_index(drop=True)
        test_speaker = test_wav_frame['test_speaker'][0]
        for spk in unique_train_speaker:
            frame = test_wav_frame[test_wav_frame['train_speaker'] == spk]
            max_score = max(frame['predict_score'])
            score_list.append([wav, test_speaker, spk, max_score])
    print(score_list)

    out_path = '../result/no_sur/predict_score_list.txt'
    # to_csv() does not create missing directories; make sure the output
    # directory exists so the run does not crash after all the scoring work.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    score_frame = pd.DataFrame(score_list)
    score_frame.to_csv(out_path, header=False)




if __name__ == "__main__":
    # Time the whole run and report the elapsed wall-clock seconds.
    t0 = datetime.datetime.now()
    main()
    elapsed = datetime.datetime.now() - t0
    print(elapsed.seconds, 's')
